import itertools
from HPO.hpo_logger import HPOLogger
from models.utils.continual_model import ContinualModel
from torch.utils.data import Dataset
from datasets.utils.continual_dataset import ContinualDataset
from argparse import Namespace


class Hyperparam:
    """A named, tunable hyperparameter accessed through caller-supplied callables.

    Args:
        name: Identifier of the hyperparameter.
        getter: Zero-argument callable returning the current value.
        setter: One-argument callable that stores a new value.
        val_list: Candidate values to explore for this hyperparameter.
    """

    def __init__(self, name, getter, setter, val_list):
        self.name = name
        self.val_list = val_list
        # Access to the real storage is delegated to these callables.
        self._get, self._set = getter, setter

    def _fetch(self):
        """Read the current value via the getter."""
        return self._get()

    def _store(self, new_value):
        """Write ``new_value`` via the setter."""
        self._set(new_value)

    value = property(_fetch, _store, doc="Current value, proxied through getter/setter.")


class BaseGridHPO:
    """Base class for grid-search hyperparameter optimisation.

    Hyperparameters are registered by name together with a getter, a setter and a
    list of candidate values. ``grid`` enumerates the Cartesian product of all
    candidate lists, and ``set_hyperparams`` applies one such setting through the
    registered setters. Subclasses implement ``select_hyperparams``.
    """

    def __init__(self):
        # name -> Hyperparam; dict insertion order defines the grid's key order.
        self.hyperparams = {}

    def register_hyperparam(self, name, getter, setter, val_list):
        """Register a hyperparameter for the search.

        Args:
            name: Unique identifier of the hyperparameter.
            getter: Zero-argument callable returning the current value.
            setter: One-argument callable storing a new value.
            val_list: Candidate values to explore.

        Raises:
            ValueError: If ``name`` is already registered.
        """
        if name in self.hyperparams:
            raise ValueError(name + " has already been registered for HPO")
        self.hyperparams[name] = Hyperparam(name, getter, setter, val_list)

    def unregister_hyperparam(self, name):
        """Remove ``name`` from the search.

        Raises:
            ValueError: If ``name`` was never registered.
        """
        try:
            del self.hyperparams[name]
        except KeyError:
            raise ValueError(
                name + " has not been registered for HPO so cannot be unregistered"
            ) from None

    def set_val_list_for_hyperparam(self, name, new_val_list):
        """Replace the candidate values of ``name``; warn (no error) if unregistered."""
        hyperparam = self.hyperparams.get(name)
        if hyperparam is None:
            print("Warning: " + name + " is not a registered hyperparam so can not be set")
            return
        hyperparam.val_list = new_val_list

    def grid(self):
        """Yield every combination of candidate values as a ``{name: value}`` dict.

        With no registered hyperparameters a single empty dict is yielded
        (the Cartesian product of zero lists has one element).
        """
        # Snapshot names once so keys stay aligned with values even if the
        # registry is modified while this generator is being consumed.
        names = list(self.hyperparams)
        val_lists = [self.hyperparams[name].val_list for name in names]
        for setting in itertools.product(*val_lists):
            yield dict(zip(names, setting))

    def set_hyperparams(self, setting):
        """Apply a ``{name: value}`` setting via the registered setters.

        Unknown names are skipped with a printed warning rather than an error.
        """
        for name, val in setting.items():
            if name in self.hyperparams:
                self.hyperparams[name].value = val
            else:
                print("Warning: "+name+" is not a registered hyperparam so can not be set")

    def select_hyperparams(self, dataset: Dataset, model: ContinualModel, logger: HPOLogger,
                           data_stream: ContinualDataset, args: Namespace, task_id: int) -> None:
        """Choose and apply hyperparameters for ``task_id``; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

